from copy import deepcopy
from functools import wraps
from inspect import getcallargs, signature
import logging
from typing import TypeAlias, Literal, overload, Optional, Callable, Any, ParamSpec, TypeVar, Union

import arcpy
import arcgis.geometry as ag

from ng911ok.lib.gdbsession import GPMessenger

_logger = logging.getLogger(__name__)

# Generic placeholders used by _bind_and_call and the decorator factories to
# carry a wrapped function's return type (_T) and parameter list (_P).
_T = TypeVar("_T")
_P = ParamSpec("_P")

# Values accepted for the ``method`` argument of the arcpy measurement calls
# made in this module (angleAndDistanceTo, pointFromAngleAndDistance).
MeasurementMethod: TypeAlias = Literal["GEODESIC", "PLANAR", "GREAT_ELLIPTIC", "LOXODROME", "PRESERVE_SHAPE"]
# Unions accepting either the arcpy or the arcgis.geometry flavor of a type.
AnyPointGeom: TypeAlias = Union[arcpy.PointGeometry, ag.Point]
AnyMultiPoint: TypeAlias = Union[arcpy.Multipoint, ag.MultiPoint]
AnyPolyline: TypeAlias = Union[arcpy.Polyline, ag.Polyline]
AnyPolygon: TypeAlias = Union[arcpy.Polygon, ag.Polygon]
AnySR: TypeAlias = Union[arcpy.SpatialReference, ag.SpatialReference]
# Every supported type from both libraries; note that spatial references are
# included even though they are not geometries.
AnyGeometry: TypeAlias = Union[arcpy.PointGeometry, ag.Point, arcpy.Multipoint, ag.MultiPoint, arcpy.Polyline, ag.Polyline, arcpy.Polygon, ag.Polygon, arcpy.SpatialReference, ag.SpatialReference]

# Constrained TypeVars, so that a signature such as ``(line: T_Polyline) ->
# T_Polyline`` ties the return type to the type of the argument.
T_PointGeom = TypeVar("T_PointGeom", arcpy.PointGeometry, ag.Point)
T_MultiPoint = TypeVar("T_MultiPoint", arcpy.Multipoint, ag.MultiPoint)
T_Polyline = TypeVar("T_Polyline", arcpy.Polyline, ag.Polyline)
T_Polygon = TypeVar("T_Polygon", arcpy.Polygon, ag.Polygon)
T_SR = TypeVar("T_SR", arcpy.SpatialReference, ag.SpatialReference)
T_Geometry = TypeVar("T_Geometry", arcpy.PointGeometry, ag.Point, arcpy.Multipoint, ag.MultiPoint, arcpy.Polyline, ag.Polyline, arcpy.Polygon, ag.Polygon, arcpy.SpatialReference, ag.SpatialReference)

# Index-aligned lookup tables: _ARCPY_GEOM_TYPES[i] and _AG_GEOM_TYPES[i] are
# counterparts (see _get_corresponding_geom_type), so both tuples MUST be kept
# in the same order. Lookups compare exact types, not subclasses.
_ARCPY_GEOM_TYPES: tuple[type[arcpy.Geometry | arcpy.SpatialReference], ...] = (arcpy.Geometry, arcpy.SpatialReference, arcpy.PointGeometry, arcpy.Multipoint, arcpy.Polyline, arcpy.Polygon)
_AG_GEOM_TYPES: tuple[type[ag.BaseGeometry], ...] = (ag.Geometry, ag.SpatialReference, ag.Point, ag.MultiPoint, ag.Polyline, ag.Polygon)


class InvalidGeometryError(ValueError):
    """
    Raised when a geometry object is malformed or of an unsupported type.

    The exception message contains a text rendering of the offending object,
    truncated to 1000 characters; the object itself is kept on
    :attr:`geometry`.
    """

    def __init__(self, geometry: AnyGeometry):
        geometry_str: str
        if isinstance(geometry, arcpy.Geometry):
            geometry_str = geometry.JSON
        elif isinstance(geometry, arcpy.SpatialReference):
            # Part of AnyGeometry but not an arcpy.Geometry, so it needs its
            # own rendering instead of being reported as an invalid argument.
            geometry_str = geometry.exportToString()
        elif isinstance(geometry, ag.BaseGeometry):
            geometry_str = str(geometry)
        else:
            _logger.error("Invalid type for argument 'geometry': %s", type(geometry))
            geometry_str = f"<Invalid argument provided for 'geometry'>:\n\t\t{geometry}"
        # Keep messages readable for very large geometries.
        if len(geometry_str) > 1000:
            geometry_str = f"{geometry_str[:997]}..."
        super().__init__(f"Invalid geometry of type {type(geometry)}:\n\t{geometry_str}")
        self.geometry: AnyGeometry = geometry


def _get_corresponding_geom_type(geom_type: type[AnyGeometry]) -> type[AnyGeometry]:
    """Return the counterpart of *geom_type* from the other geometry library.

    An ``arcpy`` type maps to its ``arcgis.geometry`` equivalent and vice
    versa, using the index-aligned ``_ARCPY_GEOM_TYPES`` / ``_AG_GEOM_TYPES``
    tables. Only exact table entries are recognized (no subclass lookup).

    :raises ValueError: If *geom_type* appears in neither table.
    """
    lookup_directions = (
        (_ARCPY_GEOM_TYPES, _AG_GEOM_TYPES),
        (_AG_GEOM_TYPES, _ARCPY_GEOM_TYPES),
    )
    for known_types, counterparts in lookup_directions:
        if geom_type in known_types:
            position = known_types.index(geom_type)
            return counterparts[position]
    raise ValueError(f"No corresponding type for '{geom_type.__module__}.{geom_type.__name__}'.")

def _bind_and_call(func: Callable[_P, _T], /, arg_dict: dict[str, Any]) -> _T:
    """Invoke *func* using a name -> value mapping of its arguments.

    *arg_dict* is expected to hold one entry per parameter of *func* (the
    shape produced by ``inspect.getcallargs``): the ``*args`` entry is a
    sequence and the ``**kwargs`` entry is a mapping. Positional-only and
    positional-or-keyword values are passed positionally, keyword-only values
    by name, and the variadic entries are unpacked.
    """
    sig = signature(func)
    positional: list[Any] = []
    keyword: dict[str, Any] = {}
    for name, parameter in sig.parameters.items():
        kind = parameter.kind
        if kind == parameter.POSITIONAL_ONLY or kind == parameter.POSITIONAL_OR_KEYWORD:
            positional.append(arg_dict[name])
        elif kind == parameter.VAR_POSITIONAL:
            positional.extend(arg_dict[name])
        elif kind == parameter.KEYWORD_ONLY:
            keyword[name] = arg_dict[name]
        elif kind == parameter.VAR_KEYWORD:
            keyword.update(arg_dict[name])
    # Validate the assembled call against the signature before invoking.
    bound = sig.bind(*positional, **keyword)
    return func(*bound.args, **bound.kwargs)

def _cast_single_geometry(geom_object: AnyGeometry, target_type: type[T_Geometry]) -> T_Geometry:
    """Return *geom_object* as an object of the library *target_type* belongs to.

    If *geom_object* already comes from the same library as *target_type*, it
    is returned unchanged. Otherwise it is converted to the counterpart of its
    own type: arcpy -> arcgis.geometry through the counterpart's constructor,
    and arcgis.geometry -> arcpy through ``as_arcpy``.

    :raises ValueError: If *target_type* belongs to neither library.
    :raises InvalidGeometryError: If ``as_arcpy`` rejects the geometry type.
    :raises RuntimeError: If *geom_object* is not a supported geometry object.
    """
    if isinstance(geom_object, (arcpy.Geometry, arcpy.SpatialReference)):
        if target_type in _AG_GEOM_TYPES:
            counterpart_type = _get_corresponding_geom_type(type(geom_object))
            return counterpart_type(geom_object)
        if target_type in _ARCPY_GEOM_TYPES:
            return geom_object
        raise ValueError(f"target_type '{target_type.__name__}' is not allowed.")

    if isinstance(geom_object, _AG_GEOM_TYPES):
        if target_type not in _ARCPY_GEOM_TYPES:
            if target_type in _AG_GEOM_TYPES:
                return geom_object
            raise ValueError(f"target_type '{target_type.__name__}' is not allowed.")
        try:
            return geom_object.as_arcpy
        except ValueError as conversion_error:
            # Only this exact message denotes a bad geometry; anything else
            # is unexpected.
            if conversion_error.args != ("Invalid geometry type for method", ):
                conversion_error.add_note("This is a bug. Please report to the developers.")
                raise conversion_error
            raise InvalidGeometryError(geom_object) from conversion_error

    bug = RuntimeError("Wrong type passed to _cast_single_geometry(). This is a bug. Please report to the developers.")
    bug.add_note(f"geom_object [{type(geom_object)}]: {geom_object}")
    _logger.critical("Wrong type passed to _cast_single_geometry().", exc_info=bug)
    raise bug

def cast_input_geometry(**targets: type[AnyGeometry]) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Decorator factory to allow decorated functions to take geometry objects
    from either the ``arcpy`` or ``arcgis`` packages while ensuring that
    parameters specified in *targets* will always be of one specific type.

    For example, consider these functions::

        def hopefully_arcpy_polygon(polygon_obj: arcpy.Polygon | arcgis.geometry.Polygon) -> None:
            assert isinstance(polygon_obj, arcpy.Polygon)

        @cast_input_geometry(polygon_obj=arcpy.Polygon)
        def definitely_arcpy_polygon(polygon_obj: arcpy.Polygon | arcgis.geometry.Polygon) -> None:
            assert isinstance(polygon_obj, arcpy.Polygon)

    The function ``hopefully_arcpy_polygon`` will only succeed if it is called
    with an instance of ``arcpy.Polygon`` as the ``polygon_obj`` argument.
    The function ``definitely_arcpy_polygon``, however, will succeed even if
    ``polygon_obj`` is an instance of ``arcgis.geometry.Polygon`` because of
    the decorator.

    Conversion only switches libraries: an argument is converted to the
    counterpart of its *own* type in the library that the target type belongs
    to. Values collected by a ``*args`` parameter (a tuple) are cast
    element-wise. Arguments that are ``None``, such as an optional geometry
    parameter left at its default, are passed through unchanged.

    :param targets: Mapping of parameter name(s) in the decorated function to
        the desired type of each argument
    :type targets: type[arcpy.Geometry | arcpy.SpatialReference | arcgis.geometry.BaseGeometry]
    :return: A decorator
    :rtype: Callable[[Callable], Callable]
    :raises TypeError: When the decorated function is called, if a key of
        *targets* is not one of its parameters
    """
    def _cast_input_geometry_decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        @wraps(func)
        def _wrapper(*args, **kwargs):
            # Map every parameter name to its value, including defaults.
            func_args: dict[str, Any] = getcallargs(func, *args, **kwargs)
            for arg_name, target_type in targets.items():
                if arg_name not in func_args:
                    raise TypeError(f"{func.__qualname__}() has no parameter named '{arg_name}' to cast.")
                arg: AnyGeometry | tuple[AnyGeometry, ...] | None = func_args[arg_name]
                if arg is None:
                    # Nothing to convert; leave optional geometries alone.
                    continue
                if isinstance(arg, tuple):
                    func_args[arg_name] = tuple(_cast_single_geometry(arg_item, target_type) for arg_item in arg)
                else:
                    func_args[arg_name] = _cast_single_geometry(arg, target_type)
            return _bind_and_call(func, func_args)
        return _wrapper
    return _cast_input_geometry_decorator

def cast_geometry(match_type_of: str,
                  /,
                  **targets: type[AnyGeometry]
                  ) -> Callable[[Callable[..., AnyGeometry]], Callable[..., AnyGeometry]]:
    """Similar to :func:`cast_input_geometry`, but also ensures that the return
    type of the decorated function matches that of the argument for the
    parameter named *match_type_of*.

    The reference type is taken from the argument as the caller passed it,
    before any casting. Arguments that are ``None`` are passed through
    uncast. If the function returns ``None``, or the reference argument is
    ``None``, the result is returned unchanged because there is nothing to
    cast or no type to match.

    :param match_type_of: Name of the parameter whose argument type the
        result should be converted to; must also be a key of *targets*
    :param targets: Mapping of parameter name(s) to the type each argument is
        cast to before the call
    :raises ValueError: Immediately, if *match_type_of* is not in *targets*
    :raises TypeError: When the decorated function is called, if a key of
        *targets* is not one of its parameters
    """
    if match_type_of not in targets:
        raise ValueError(f"Argument for match_type_of ('{match_type_of}') must appear as a key in targets.")

    def _cast_geometry_decorator(func: Callable[..., AnyGeometry]) -> Callable[..., AnyGeometry]:
        @wraps(func)
        def _wrapper(*args, **kwargs):
            func_args: dict[str, Any] = getcallargs(func, *args, **kwargs)
            for arg_name in targets:
                if arg_name not in func_args:
                    raise TypeError(f"{func.__qualname__}() has no parameter named '{arg_name}' to cast.")
            # Capture the caller's type before the argument is replaced below.
            reference = func_args[match_type_of]
            for arg_name, target_type in targets.items():
                arg: AnyGeometry | tuple[AnyGeometry, ...] | None = func_args[arg_name]
                if arg is None:
                    continue
                if isinstance(arg, tuple):
                    func_args[arg_name] = tuple(_cast_single_geometry(arg_item, target_type) for arg_item in arg)
                else:
                    func_args[arg_name] = _cast_single_geometry(arg, target_type)
            raw_result = _bind_and_call(func, func_args)
            if raw_result is None or reference is None:
                return raw_result
            return _cast_single_geometry(raw_result, type(reference))
        return _wrapper
    return _cast_geometry_decorator

def _resolve_vertex_indices(args: tuple[int, ...], kwargs: dict[str, int]) -> tuple[int, int]:
    """Map ``get_vertex``'s flexible index arguments to ``(part_index, vertex_index)``.

    This mirrors the two ``get_vertex`` overloads:
    - One positional index is the vertex, unless ``vertex_index`` is also
      given by keyword; in that case the positional index is the part.
    - Two positional indices are ``(part, vertex)``.
    ``part_index`` defaults to ``0``.

    :raises TypeError: On too many, unknown, duplicated or missing arguments
    """
    if len(args) > 2:
        raise TypeError(f"get_vertex() takes at most 2 index arguments ({len(args)} given)")
    unexpected = kwargs.keys() - {"part_index", "vertex_index"}
    if unexpected:
        raise TypeError(f"get_vertex() got unexpected keyword argument(s): {', '.join(sorted(unexpected))}")
    if len(args) == 2:
        positional = {"part_index": args[0], "vertex_index": args[1]}
    elif len(args) == 1:
        positional = {("part_index" if "vertex_index" in kwargs else "vertex_index"): args[0]}
    else:
        positional = {}
    duplicated = positional.keys() & kwargs.keys()
    if duplicated:
        raise TypeError(f"get_vertex() got multiple values for: {', '.join(sorted(duplicated))}")
    indices = {"part_index": 0, **kwargs, **positional}
    if "vertex_index" not in indices:
        raise TypeError("get_vertex() missing required argument: 'vertex_index'")
    return indices["part_index"], indices["vertex_index"]

@overload
def get_vertex(line: AnyPolyline, vertex_index: int) -> arcpy.PointGeometry: ...

@overload
def get_vertex(line: AnyPolyline, part_index: int, vertex_index: int) -> arcpy.PointGeometry: ...

@cast_input_geometry(line=arcpy.Polyline)
def get_vertex(line: AnyPolyline, *args: int, **kwargs: int) -> arcpy.PointGeometry:
    """
    Return one vertex of *line* as an ``arcpy.PointGeometry`` in the line's
    spatial reference.

    Call as ``get_vertex(line, vertex_index)`` (part 0) or
    ``get_vertex(line, part_index, vertex_index)``. Either index may also be
    passed by keyword.

    :param line: The line feature; ``arcgis.geometry.Polyline`` input is
        converted to ``arcpy.Polyline`` by the decorator
    :raises TypeError: If the index arguments cannot be resolved
    :raises IndexError: Propagated from indexing *line*, with a note naming
        the requested part/vertex
    """
    assert isinstance(line, arcpy.Polyline)  # guaranteed by the decorator
    part_index, vertex_index = _resolve_vertex_indices(args, kwargs)
    try:
        vertex = line[part_index][vertex_index]
    except IndexError as exc:
        exc.add_note(
            f"Requested part {part_index}, vertex {vertex_index} of a "
            f"{'multipart' if line.isMultipart else 'single-part'} polyline "
            f"with {line.partCount} part(s)."
        )
        raise
    return arcpy.PointGeometry(vertex, line.spatialReference)

def get_end_from_curve(curve_dict: dict[Literal["a", "b", "c"], list]) -> list[float | None]:
    """Given a curve dict, returns the endpoint of the curve."""
    if len(curve_dict) != 1:
        raise ValueError
    curve_key, curve_params = next(iter(curve_dict.items()))
    if curve_key not in {"a", "b", "c"}:
        raise ValueError
    return curve_params[0]

@cast_geometry("line", line=ag.Polyline)
def get_segment(line: T_Polyline, part: int, start_vertex: int, end_vertex: int) -> T_Polyline:
    """
    Extracts and returns the portion of a line between two vertices. Supports
    true curves.

    :param line: The line feature
    :type line: T_Polyline
    :param part: Index of the part; should be 0 unless the feature is multipart
    :type part: int
    :param start_vertex: Index of the first vertex to include in the returned
        segment; negative values count from the end of the part
    :type start_vertex: int
    :param end_vertex: Index of the last vertex to include in the returned
        segment; must be greater than *start_vertex*. Negative values count
        from the end of the part
    :type end_vertex: int
    :return: A single-part line feature of the same type as *line*
    :rtype: T_Polyline
    :raises ValueError: If *start_vertex* is not less than *end_vertex*
    :raises IndexError: If *part* is out of range, or the vertex indices do
        not describe a valid range within the part
    :raises InvalidGeometryError: If *line* has neither ``paths`` nor
        ``curvePaths``
    """
    line: ag.Polyline  # the decorator hands us an arcgis.geometry.Polyline
    if start_vertex >= end_vertex:
        raise ValueError("Start vertex must be less than end vertex.")
    if "paths" in line:
        key = "paths"
    elif "curvePaths" in line:
        key = "curvePaths"
    else:
        # No vertex arrays at all (e.g. an empty polyline), so nothing to slice.
        raise InvalidGeometryError(line)
    path: list = line[key][part]
    vertex_count = len(path)
    # Resolve negative indices first. Slicing with them directly can give an
    # empty or truncated segment (e.g. path[-2:0] is empty).
    first = start_vertex + vertex_count if start_vertex < 0 else start_vertex
    last = end_vertex + vertex_count if end_vertex < 0 else end_vertex
    if not 0 <= first < last < vertex_count:
        raise IndexError(
            f"Vertex range {start_vertex}..{end_vertex} is invalid for part {part}, "
            f"which has {vertex_count} vertices."
        )
    segment_path: list = path[first : last + 1]
    if isinstance(segment_path[0], dict):
        # Segment starts at the end of a curve; only that curve's end point
        # belongs to this segment.
        segment_path[0] = get_end_from_curve(segment_path[0])
    # Copy so the result keeps the line's other attributes (e.g. its
    # spatial reference) without mutating the input.
    segment: ag.Polyline = deepcopy(line)
    segment[key] = [segment_path]
    return segment

def get_segments(line: T_Polyline, part: int) -> list[T_Polyline]:
    """
    Splits a line feature at its vertices and returns a list of segments. If
    *line* is multipart, only one part's worth of segments is returned.
    Supports true curves.

    :param line: The line feature
    :type line: T_Polyline
    :param part: Index of the part; should be 0 unless the feature is multipart
    :type part: int
    :return: List of the segments that comprise the specified *part* of *line*
    :rtype: list[T_Polyline]
    """
    vertex_count: int
    match line:
        case arcpy.Polyline() as __line:
            vertex_count = len(__line.getPart(part))
        case ag.Polyline({"paths": [*paths]} | {"curvePaths": [*paths]}):
            vertex_count = len(paths[part])
        case _:
            raise ValueError(f"Could not determine vertex count.")
    segments: list[T_Polyline] = []
    for i in range(vertex_count - 1):
        segments.append(get_segment(line, part, i, i + 1))
    return segments

@cast_input_geometry(line_feature=arcpy.Polyline)
def get_midpoint(line_feature: AnyPolyline) -> arcpy.PointGeometry:
    """
    Return the point halfway along *line_feature* as an
    ``arcpy.PointGeometry``.

    ``arcgis.geometry.Polyline`` input is converted to ``arcpy.Polyline`` by
    the decorator before ``positionAlongLine`` is called.
    """
    half_of_length = 0.5  # a fraction of the length, because use_percentage=True
    return line_feature.positionAlongLine(half_of_length, use_percentage=True)

@cast_input_geometry(line_feature=arcpy.Polyline)
def point_beside_line(line_feature: AnyPolyline,
                      side: Literal["LEFT", "RIGHT"],
                      distance_meters: float,
                      segment: int = 0,
                      part: int = 0,
                      method: MeasurementMethod = "GEODESIC",
                      return_as: type[T_PointGeom] = arcpy.PointGeometry,
                      messenger: Optional[GPMessenger] = None
                      ) -> T_PointGeom:
    """
    Return a point 90 degrees to either the left or the right of one segment
    of *line_feature*, offset by *distance_meters* from that segment's
    midpoint.

    The segment is the one between vertices *segment* and *segment* + 1 of
    part *part*. "Left" and "right" are relative to the order in which the
    vertices are stored. For a true-curve segment, the direction is estimated
    from two points just either side of the segment's midpoint. For a straight
    segment, it is the direction from the first to the last vertex.

    :param side: ``"LEFT"`` or ``"RIGHT"``
    :param distance_meters: Offset distance passed to
        ``pointFromAngleAndDistance``; NOTE(review): its units depend on
        *method* in arcpy, so confirm for non-geodesic methods
    :param segment: Index of the segment's first vertex within the part
    :param part: Index of the part
    :param method: arcpy measurement method used for angles and offsets
    :param return_as: Point type to return (``arcpy`` or ``arcgis.geometry``)
    :param messenger: Optional messenger that receives a warning if the
        segment cannot be extracted
    :raises ValueError: If *side* is not ``"LEFT"`` or ``"RIGHT"``
    :raises InvalidGeometryError: If the segment cannot be extracted
    """
    line_feature: arcpy.Polyline
    side_offsets: dict[str, float] = {"LEFT": -90., "RIGHT": 90.}
    if side not in side_offsets:
        raise ValueError(f"side must be 'LEFT' or 'RIGHT', not {side!r}.")
    try:
        line_segment = get_segment(line_feature, part, segment, segment + 1)
    except InvalidGeometryError as exc:
        if messenger:
            messenger.addWarningMessage("Failed splitting line into segments.")
        exc.add_note(f"Part {part}, segment {segment}: {line_feature.JSON}")
        exc.add_note("Failed splitting line into segments.")
        raise
    midpoint = get_midpoint(line_segment)
    line_direction: float
    if "curvePaths" in ag.Polyline(line_segment):
        # True curve: approximate the tangent at the midpoint.
        point_minus_delta: arcpy.PointGeometry = line_segment.positionAlongLine(0.49999, use_percentage=True)
        point_plus_delta: arcpy.PointGeometry = line_segment.positionAlongLine(0.50001, use_percentage=True)
        line_direction, _ = point_minus_delta.angleAndDistanceTo(point_plus_delta, method)
    else:
        # Straight line: direction from first to last vertex. lastPoint is a
        # bare arcpy.Point, but angleAndDistanceTo takes a PointGeometry.
        start_point = get_vertex(line_segment, 0, 0)
        end_point = arcpy.PointGeometry(line_segment.lastPoint, line_segment.spatialReference)
        line_direction, _ = start_point.angleAndDistanceTo(end_point, method)

    direction_toward_side: float = (line_direction + side_offsets[side]) % 360
    result = midpoint.pointFromAngleAndDistance(direction_toward_side, distance_meters, method)
    return _cast_single_geometry(result, return_as)